import random
import time

import numpy as np
import pandas as pd

import networkit as nw
from sortedcontainers import SortedKeyList

# Custom modules
import kaiwu as kw
from HiQLip.CIM import CIMSolver
from HiQLip.utils import EmbeddingCoarsening

# Seed every random number generator this module relies on so that runs are
# reproducible: ``random`` drives subproblem sampling and the initial coarse
# labels; NumPy's global RNG is seeded as well (presumably for the imported
# solver modules — this file itself does not draw from it).
_SEED = 42
random.seed(_SEED)
np.random.seed(_SEED)


class Refinement:
    """
    Class for refining the solution of a graph partitioning problem.

    The solution holds a 0/1 side label per node. ``obj`` is the total weight of
    the edges whose endpoints carry different labels (see ``calc_obj``).
    Refinement repeatedly builds a subproblem from high-gain nodes and solves it
    with the configured solver while all other nodes stay fixed. A result is
    kept whenever the objective does not decrease.
    """

    def __init__(self, G, spsize, solver, solution):
        """
        Initialize the Refinement instance.

        Parameters:
        - G: networkit Graph object (node ids are used directly as list indices,
          so they are assumed to be 0..n-1)
        - spsize: int, size of the subproblem
        - solver: str, name of the solver to use ('cim')
        - solution: list, initial 0/1 solution indexed by node id
        """
        self.G = G
        self.n = G.numberOfNodes()
        # gainmap[u]: change in the objective if u alone switched sides.
        self.gainmap = [0 for _ in range(self.n)]
        self.spsize = spsize
        self.solver = solver
        self.solution = solution
        self.buildGain()
        self.obj = self.calc_obj(G, solution)
        self.last_subprob = None
        # Extra weight applied to the gains of unchanged neighbours in updateGain.
        self.alpha = 0.2
        # Consecutive rounds without strict improvement before refine() stops.
        self.max_iterations = 3

    def buildGain(self):
        """
        Build the gain map for nodes based on the current solution.

        For each edge (u, v, w): flipping u or v would cut the edge when both
        endpoints are on the same side (gain +w) and uncut it otherwise
        (gain -w). Values are added onto ``gainmap``, so it must start zeroed.
        Also builds ``gainlist``, all node ids ordered by gain plus a small
        node-id offset (0.01 * id).
        """
        for u, v, w in self.G.iterEdgesWeights():
            if self.solution[u] == self.solution[v]:
                self.gainmap[u] += w
                self.gainmap[v] += w
            else:
                self.gainmap[u] -= w
                self.gainmap[v] -= w
        # NOTE(review): the key reads the mutable gainmap, so this ordering is
        # only valid until gainmap changes. updateGain does not refresh it, and
        # nothing in this class reads gainlist afterwards.
        self.gainlist = SortedKeyList(range(self.n), key=lambda x: self.gainmap[x] + 0.01 * x)

    def calc_obj(self, G, solution):
        """
        Calculate the objective value for the current solution.

        For 0/1 labels, each edge contributes its weight when the endpoints
        differ and 0 otherwise. The result is therefore the cut weight.

        Parameters:
        - G: networkit Graph object
        - solution: list, current 0/1 solution

        Returns:
        - obj: float, objective value
        """
        obj = 0
        for u, v, w in G.iterEdgesWeights():
            obj += w * (2 * solution[u] * solution[v] - solution[u] - solution[v])
        return -1 * obj

    def refine_coarse(self):
        """
        Solve the whole graph directly with the specified solver.

        Replaces ``solution`` and ``obj`` with the solver's result and rebuilds
        the gain map so it describes the new solution.

        Returns:
        - obj: the objective value reported by the solver

        Raises:
        - ValueError: if ``solver`` is not 'cim'
        """
        if self.solver != 'cim':
            raise ValueError(f"Unknown solver: {self.solver}")
        solver = CIMSolver()
        self.solution, self.obj = solver.solve(self.G)
        # The gain map was built for the initial solution; rebuild it so a
        # later refine() on this instance samples from up-to-date gains.
        self.gainmap = [0 for _ in range(self.n)]
        self.buildGain()
        return self.obj

    @staticmethod
    def _alignSubSolution(S, idx):
        """
        Relabel a subproblem solution so that supernode ``idx`` maps to 0.

        Max-cut is invariant under swapping the two sides, so the subproblem
        solver may return the mirror image of the intended assignment.
        Supernode ``idx`` stands for every outside node labelled 0. Each
        subproblem node on the same side as it therefore becomes 0, and
        every other node becomes 1.
        """
        anchor = S[idx]
        return [0 if label == anchor else 1 for label in S]

    def refine(self):
        """
        Refine the solution iteratively using subproblems.

        Each round runs these steps:
        1. Build a subproblem (``randGainSubProb``) and solve it.
        2. Map the labels back onto the original nodes.
        3. Update the objective incrementally from the nodes that changed side.

        A result at least as good as the current one is accepted, and a strict
        improvement resets the round counter. The loop stops after
        ``max_iterations`` consecutive rounds without strict improvement.
        """
        count = 0
        while count < self.max_iterations:
            # Generate subproblem
            subprob, mapProbToSubProb, idx = self.randGainSubProb()
            if self.solver == 'cim':
                solver = CIMSolver()
                # The solver's own objective refers to the subproblem graph and
                # is not used; the true change is computed below.
                S, _ = solver.solve(subprob)
            else:
                raise ValueError(f"Unknown solver: {self.solver}")

            # Orient the labels so supernode idx (outside nodes with label 0)
            # is 0; otherwise a mirrored solution would be mapped back wrongly.
            aligned = self._alignSubSolution(S, idx)

            new_sol = self.solution.copy()
            for node, sub_node in mapProbToSubProb.items():
                new_sol[node] = aligned[sub_node]

            # Identify changed nodes
            changed = {u for u in self.last_subprob if self.solution[u] != new_sol[u]}

            # Only edges between a changed and an unchanged node switch cut
            # status; edges with both endpoints changed keep theirs.
            new_obj = self.obj
            for u in changed:
                for v in self.G.iterNeighbors(u):
                    if v not in changed:
                        w = self.G.weight(u, v)
                        if new_sol[u] == new_sol[v]:
                            new_obj -= w
                        else:
                            new_obj += w

            count += 1
            if new_obj >= self.obj:
                self.updateGain(new_sol, changed)
                self.solution = new_sol  # new_sol is already a fresh copy
                if new_obj > self.obj:
                    count = 0
                    self.obj = new_obj

    def refineLevel(self):
        """
        Perform one level of refinement.
        """
        self.refine()

    def updateGain(self, S, changed):
        """
        Update the gain map after the nodes in ``changed`` switched sides.

        Only an edge between a changed node u and an unchanged neighbour v
        changes cut status. Its contribution to both endpoints' gains moves
        by 2*w: up if u and v now share a side, down otherwise. The changed
        node receives the exact 2*w. The unchanged neighbour receives
        2*w*(1 + alpha), a deliberate bias controlled by ``alpha``.

        Parameters:
        - S: list, new solution
        - changed: set, nodes that have changed their assignment
        """
        for u in changed:
            for v in self.G.iterNeighbors(u):
                if v in changed:
                    continue
                exact = 2 * self.G.weight(u, v)
                biased = exact * (1 + self.alpha)
                if S[u] == S[v]:
                    self.gainmap[u] += exact
                    self.gainmap[v] += biased
                else:
                    self.gainmap[u] -= exact
                    self.gainmap[v] -= biased

    def randGainSubProb(self):
        """
        Generate a subproblem based on the gain map.

        Nodes are chosen in these steps:
        1. Randomly sample nodes: max(20% of n, 2 * spsize) of them, or all
           nodes when n < 2 * spsize.
        2. Keep the ``spsize`` sampled nodes with the highest gain.

        Every edge to a node outside the subproblem is redirected to supernode
        ``idx`` (outside label 0) or ``idx + 1`` (outside label 1). The two
        supernodes are joined by an edge carrying the graph's remaining weight.

        Returns:
        - subprob: networkit Graph, the subproblem graph
        - mapProbToSubProb: dict, mapping from original nodes to subproblem nodes
        - idx: int, index for the supernodes in the subproblem
        """
        if self.n >= 2 * self.spsize:
            sample_size = max(int(0.2 * self.n), 2 * self.spsize)
        else:
            sample_size = self.n

        # Randomly sample nodes, then keep the highest-gain ones.
        sample = random.sample(range(self.n), sample_size)
        nodes = sorted(sample, reverse=True, key=lambda x: self.gainmap[x])
        spnodes = nodes[:self.spsize]

        # Create subproblem graph: subproblem nodes 0..k-1, supernodes k, k+1.
        num_spnodes = len(spnodes)
        subprob = nw.graph.Graph(n=num_spnodes + 2, weighted=True, directed=False)
        mapProbToSubProb = {u: i for i, u in enumerate(spnodes)}
        idx = num_spnodes  # index of the first supernode

        self.last_subprob = spnodes

        # Build subproblem edges
        for u in spnodes:
            spu = mapProbToSubProb[u]
            for v in self.G.iterNeighbors(u):
                w = self.G.weight(u, v)
                if v not in mapProbToSubProb:
                    # Connect to supernodes based on solution
                    spv = idx if self.solution[v] == 0 else idx + 1
                    subprob.increaseWeight(spu, spv, w)
                else:
                    spv = mapProbToSubProb[v]
                    if u < v:  # add each internal edge once
                        subprob.increaseWeight(spu, spv, w)

        # Add edge between supernodes
        total = subprob.totalEdgeWeight()
        subprob.increaseWeight(idx, idx + 1, self.G.totalEdgeWeight() - total)
        return subprob, mapProbToSubProb, idx


class Solver:
    """
    Class to solve the graph partitioning problem using coarsening and refinement.
    """

    def __init__(self, adj, sp, solver, ratio):
        """
        Initialize the Solver instance.

        Only entries strictly above the diagonal of ``adj`` are read; each
        nonzero entry (i, j) becomes an undirected edge of that weight.

        Parameters:
        - adj: numpy array, adjacency matrix
        - sp: int, size of the subproblem
        - solver: str, name of the solver to use ('cim')
        - ratio: float, coarsening ratio
        """
        self.problem_graph = nw.graph.Graph(n=adj.shape[0], weighted=True, directed=False)
        self.adj = adj
        self.hierarchy = []
        self.spsize = sp
        self.solver = solver
        self.solution = None
        self.obj = 0
        self.start = time.perf_counter()
        self.ratio = ratio

        # Build the initial graph (vectorized scan instead of an O(n^2) Python loop).
        for i, j, w in self._upperTriangleEdges(adj):
            self.problem_graph.addEdge(i, j, w)

    @staticmethod
    def _upperTriangleEdges(adj):
        """
        Return ``(i, j, w)`` for every nonzero entry strictly above the diagonal.

        Edges come out in row-major order, i.e. the order a nested
        ``for i: for j > i`` scan would visit them.
        """
        dense = np.asarray(adj)
        rows, cols = np.nonzero(np.triu(dense, k=1))
        return list(zip(rows.tolist(), cols.tolist(), dense[rows, cols].tolist()))

    def solve(self):
        """
        Solve the problem using a multilevel approach with coarsening and refinement.

        The method runs in three phases:
        1. Coarsen the graph with ``EmbeddingCoarsening`` until it has at most
           ``spsize`` nodes.
        2. Solve the coarsest graph directly (``Refinement.refine_coarse``).
        3. Walk back down the hierarchy. At each level, every fine node takes
           the label of its coarse node, and ``Refinement.refineLevel`` then
           improves the result.

        ``self.solution`` and ``self.obj`` hold the final result.
        """
        G = self.problem_graph
        # Start from an empty hierarchy so repeated calls do not mix in (or
        # re-reverse) levels from an earlier run.
        self.hierarchy = []

        # Coarsening phase
        # NOTE(review): termination relies on each EmbeddingCoarsening step
        # shrinking the graph for the configured ratio; confirm in HiQLip.utils.
        while G.numberOfNodes() > self.spsize:
            E = EmbeddingCoarsening(G, 4, 'sphere', self.ratio)
            E.coarsen()
            self.hierarchy.append(E)
            G = E.cG

        # Reverse the hierarchy for uncoarsening (coarsest level first)
        self.hierarchy.reverse()

        # Initial solution on the coarsest graph
        R = Refinement(
            G,
            self.spsize,
            self.solver,
            [random.randint(0, 1) for _ in range(G.numberOfNodes())],
        )
        self.coarse_obj = R.refine_coarse()
        self.obj = R.obj
        self.solution = R.solution

        # Uncoarsening and refinement phase
        for i, E in enumerate(self.hierarchy):
            # The finest level is refined on the original problem graph.
            G = E.G if i != len(self.hierarchy) - 1 else self.problem_graph
            fineToCoarse = E.mapFineToCoarse

            # Project solution to finer graph
            self.solution = [self.solution[fineToCoarse[j]] for j in range(G.numberOfNodes())]

            # Refinement on the finer graph
            R = Refinement(G, self.spsize, self.solver, self.solution)
            R.refineLevel()
            self.solution = R.solution
            self.obj = R.obj


def HiQLipsolver(w, userid: str = '0', sdkcode: str = '0', *, sp: int = 95, ratio: float = 0.0):
    """
    HiQLip solver function.

    Runs the multilevel ``Solver`` with the 'cim' backend and converts its
    0/1 labels to -1/+1.

    Parameters:
    - w: numpy array, adjacency matrix
    - userid: str, user ID (default '0'); not used by this function
    - sdkcode: str, SDK code (default '0'); not used by this function
    - sp: int, subproblem size passed to ``Solver`` (default 95, keyword-only)
    - ratio: float, coarsening ratio passed to ``Solver`` (default 0.0,
      keyword-only)

    Returns:
    - obj: float, objective value tracked by ``Solver``. It comes from the last
      refinement level, or is the value reported by the CIM solver when no
      coarsening took place.
    - solution: numpy array, solution vector with entries -1 / +1
    """
    M = Solver(adj=w, sp=sp, solver='cim', ratio=ratio)
    M.solve()
    # Adjust solution to be -1 or 1 (0 -> -1, 1 -> +1)
    M.solution = np.array(M.solution) * 2 - 1
    return M.obj, M.solution
